{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 实现标准简单线性回归\n",
    "\n",
    "具体公式不谈，主要都是代码"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "\n",
    "# Toy data set: one [x, y] pair per row (x is the input, y is the target).\n",
    "dataset = [\n",
    "    [1.2, 1.1],\n",
    "    [2.4, 3.5],\n",
    "    [4.1, 3.2],\n",
    "    [3.4, 2.8],\n",
    "    [5, 5.4],\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mean(values):\n",
    "    \"\"\"Return the arithmetic mean of ``values``.\n",
    "\n",
    "    Args:\n",
    "        values: Non-empty sequence of numbers (must support ``len``).\n",
    "\n",
    "    Returns:\n",
    "        The mean, computed with float division.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``values`` is empty.\n",
    "    \"\"\"\n",
    "    # Report empty input explicitly instead of a bare ZeroDivisionError.\n",
    "    if len(values) == 0:\n",
    "        raise ValueError('mean() requires at least one value')\n",
    "    return sum(values) / float(len(values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def covariance(x, mean_x, y, mean_y):\n",
    "    \"\"\"Return the summed co-deviation of ``x`` and ``y`` about their means.\n",
    "\n",
    "    The result is sum((x_i - mean_x) * (y_i - mean_y)); it is NOT divided\n",
    "    by the number of samples.\n",
    "\n",
    "    Args:\n",
    "        x: Sequence of numbers.\n",
    "        mean_x: Mean of ``x``.\n",
    "        y: Sequence of numbers with the same length as ``x``.\n",
    "        mean_y: Mean of ``y``.\n",
    "\n",
    "    Returns:\n",
    "        The accumulated sum as a float (0.0 for empty input).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``x`` and ``y`` differ in length.\n",
    "    \"\"\"\n",
    "    # Without this check a shorter y raises IndexError and a longer y has its\n",
    "    # trailing values silently ignored.\n",
    "    if len(x) != len(y):\n",
    "        raise ValueError('x and y must have the same length')\n",
    "    covar = 0.0\n",
    "    for x_i, y_i in zip(x, y):\n",
    "        covar += (x_i - mean_x) * (y_i - mean_y)\n",
    "    return covar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def variance(values, mean):\n",
    "    \"\"\"Return the sum of squared deviations of ``values`` from ``mean``.\n",
    "\n",
    "    The total is not divided by the number of values. Inside this function\n",
    "    the ``mean`` parameter shadows the notebook-level ``mean`` helper.\n",
    "    \"\"\"\n",
    "    total = 0\n",
    "    for value in values:\n",
    "        total += (value - mean) ** 2\n",
    "    return total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def coefficients(dataset):\n",
    "    \"\"\"Fit the line y = w0 + w1 * x to ``dataset`` by least squares.\n",
    "\n",
    "    Args:\n",
    "        dataset: Non-empty sequence of rows; row[0] is x and row[1] is y.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (w0, w1) holding the intercept and the slope.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``dataset`` is empty or every x value is identical\n",
    "            (the slope is undefined in both cases).\n",
    "    \"\"\"\n",
    "    if len(dataset) == 0:\n",
    "        raise ValueError('dataset must not be empty')\n",
    "    x = [row[0] for row in dataset]\n",
    "    y = [row[1] for row in dataset]\n",
    "    x_mean, y_mean = mean(x), mean(y)\n",
    "    # variance() yields the sum of squared deviations of x; zero means all x\n",
    "    # are equal and the division below would raise a bare ZeroDivisionError.\n",
    "    x_spread = variance(x, x_mean)\n",
    "    if x_spread == 0:\n",
    "        raise ValueError('cannot fit a line: all x values are identical')\n",
    "    w1 = covariance(x, x_mean, y, y_mean) / x_spread\n",
    "    w0 = y_mean - w1 * x_mean\n",
    "    return (w0, w1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse_metric(actual, predicted):\n",
    "    \"\"\"Return the root mean squared error between two sequences.\n",
    "\n",
    "    Args:\n",
    "        actual: Non-empty sequence of observed values.\n",
    "        predicted: Sequence of predictions, same length as ``actual``.\n",
    "\n",
    "    Returns:\n",
    "        sqrt of the mean of (predicted_i - actual_i) ** 2, as a float.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the inputs are empty or differ in length.\n",
    "    \"\"\"\n",
    "    # Without this check a shorter ``predicted`` raises IndexError and a longer\n",
    "    # one has its extra entries silently ignored.\n",
    "    if len(actual) != len(predicted):\n",
    "        raise ValueError('actual and predicted must have the same length')\n",
    "    if len(actual) == 0:\n",
    "        raise ValueError('rmse_metric() requires at least one value')\n",
    "    sum_error = 0.0\n",
    "    for observed, estimate in zip(actual, predicted):\n",
    "        sum_error += (estimate - observed) ** 2\n",
    "    return sqrt(sum_error / float(len(actual)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_linear_regression(train, test):\n",
    "    \"\"\"Fit a line on ``train`` and predict y for every row of ``test``.\n",
    "\n",
    "    The intercept and slope come from the closed-form least-squares\n",
    "    solution, so no iterative optimisation is involved. Only row[0] (the x\n",
    "    value) of each test row is read.\n",
    "\n",
    "    Returns:\n",
    "        A list with one prediction per row of ``test``.\n",
    "    \"\"\"\n",
    "    intercept, slope = coefficients(train)\n",
    "    return [slope * row[0] + intercept for row in test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_algorithm(dataset, algorithm):\n",
    "    \"\"\"Run ``algorithm`` on ``dataset`` and return the RMSE of its predictions.\n",
    "\n",
    "    The same rows are used for fitting and for scoring, so the result is a\n",
    "    training error. Each prediction is printed on its own line.\n",
    "\n",
    "    Args:\n",
    "        dataset: Sequence of rows whose last element is the target value.\n",
    "        algorithm: Callable taking (train, test) and returning predictions.\n",
    "\n",
    "    Returns:\n",
    "        The value of rmse_metric(actual, predicted).\n",
    "    \"\"\"\n",
    "    def _hide_target(row):\n",
    "        # Work on a copy so the caller's row keeps its target value.\n",
    "        masked = list(row)\n",
    "        masked[-1] = None\n",
    "        return masked\n",
    "\n",
    "    test_set = [_hide_target(row) for row in dataset]\n",
    "    predicted = algorithm(dataset, test_set)\n",
    "    for val in predicted:\n",
    "        print('%.3f\\t' % val)\n",
    "\n",
    "    targets = [row[-1] for row in dataset]\n",
    "    return rmse_metric(targets, predicted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.386\t\n",
      "2.463\t\n",
      "3.990\t\n",
      "3.362\t\n",
      "4.799\t\n",
      "RMSE: 0.701\n"
     ]
    }
   ],
   "source": [
    "# Entry point: fit and score the regression on the sample data, then report the error.\n",
    "rmse = evaluate_algorithm(dataset=dataset, algorithm=simple_linear_regression)\n",
    "print('RMSE: %.3f' % rmse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
